package com.whoyx.jiebing.utils.nlp.sentence_encoder;

import ai.djl.Model;
import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.StackBatchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.List;

/**
 * DJL {@link Translator} that turns a sentence into BERT-style model inputs and returns the
 * model output named {@code sentence_embedding} as a {@code float[]}.
 *
 * <p>The vocabulary and tokenizer are created in {@link #prepare(TranslatorContext)} from the
 * model's {@code vocab.txt} artifact, so that method must run before any input is processed.
 *
 * @author stone
 */
@Slf4j
public final class SentenceTransTranslator implements Translator<String, float[]> {

    /** Maximum number of ids sent to the model, including the two marker tokens. */
    private static final int MAX_SEQUENCE_LENGTH = 128;

    /** Number of marker tokens ([CLS] and [SEP]) wrapped around the word pieces. */
    private static final int SPECIAL_TOKEN_COUNT = 2;

    private static final String CLS_TOKEN = "[CLS]";
    private static final String SEP_TOKEN = "[SEP]";
    private static final String UNKNOWN_TOKEN = "[UNK]";
    private static final String VOCABULARY_ARTIFACT = "vocab.txt";

    /** Name of the output array that is returned to the caller. */
    private static final String SENTENCE_EMBEDDING = "sentence_embedding";

    // Both fields are assigned in prepare(), which is why they cannot be final.
    private DefaultVocabulary vocabulary;
    private BertFullTokenizer tokenizer;

    /** Creates a translator; the vocabulary is loaded later in {@link #prepare}. */
    public SentenceTransTranslator() {}

    /**
     * Returns a stacking batchifier.
     *
     * <p>NOTE(review): inputs are not padded to a common length in {@link #processInput}, so
     * stacking sentences with different token counts presumably fails - confirm before relying
     * on batch prediction.
     *
     * @return a new {@link StackBatchifier}
     */
    @Override
    public Batchifier getBatchifier() {
        return new StackBatchifier();
    }

    /**
     * Loads the vocabulary from the model's {@code vocab.txt} artifact and builds the tokenizer.
     *
     * @param ctx translator context giving access to the loaded model
     * @throws IOException if the vocabulary artifact cannot be read
     */
    @Override
    public void prepare(TranslatorContext ctx) throws IOException {
        Model model = ctx.getModel();
        URL vocabularyUrl = model.getArtifact(VOCABULARY_ARTIFACT);
        vocabulary =
                DefaultVocabulary.builder()
                        .optMinFrequency(1)
                        .addFromTextFile(vocabularyUrl)
                        .optUnknownToken(UNKNOWN_TOKEN)
                        .build();
        // Second argument is the tokenizer's lower-casing flag; it is deliberately left off.
        tokenizer = new BertFullTokenizer(vocabulary, false);
    }

    /**
     * Picks the array named {@code sentence_embedding} out of the model output.
     *
     * @param ctx translator context (unused)
     * @param list arrays returned by the model
     * @return the sentence embedding as a float array
     * @throws IllegalStateException if the model output has no {@code sentence_embedding} array
     */
    @Override
    public float[] processOutput(TranslatorContext ctx, NDList list) {
        // The order of the outputs is not fixed and may change from run to run, so the
        // embedding is looked up by name rather than by position. Outputs seen so far:
        //  input_ids
        //  token_type_ids
        //  attention_mask
        //  token_embeddings: (13, 384) cpu() float32
        //  cls_token_embeddings: (384) cpu() float32
        //  sentence_embedding: (384) cpu() float32
        for (NDArray candidate : list) {
            if (SENTENCE_EMBEDDING.equals(candidate.getName())) {
                return candidate.toFloatArray();
            }
        }

        // An explicit exception instead of `assert`: assertions are disabled by default, which
        // previously turned a missing output into an uninformative NullPointerException.
        throw new IllegalStateException(
                "Model output does not contain an array named '" + SENTENCE_EMBEDDING + "'");
    }

    /**
     * Tokenizes the sentence and builds the {@code input_ids} and {@code attention_mask} arrays.
     *
     * <p>The token list is truncated so that, together with [CLS] and [SEP], at most
     * {@value #MAX_SEQUENCE_LENGTH} ids are produced. Example for a 13-id input:
     *
     * <pre>
     * input_ids      [101 [CLS], 2023 this, 7705 framework, ..., 6251 sentence, 102 [SEP]]
     * attention_mask [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
     * </pre>
     *
     * @param ctx translator context supplying the {@link NDManager}
     * @param input sentence to encode
     * @return list holding {@code input.input_ids} followed by {@code input.attention_mask}
     */
    @Override
    public NDList processInput(TranslatorContext ctx, String input) {
        List<String> tokens = tokenizer.tokenize(input);
        int maxTokenCount = MAX_SEQUENCE_LENGTH - SPECIAL_TOKEN_COUNT;
        if (tokens.size() > maxTokenCount) {
            tokens = tokens.subList(0, maxTokenCount);
        }

        // Layout: [CLS] token_1 ... token_n [SEP]
        long[] inputIds = new long[tokens.size() + SPECIAL_TOKEN_COUNT];
        inputIds[0] = vocabulary.getIndex(CLS_TOKEN);
        for (int i = 0; i < tokens.size(); i++) {
            inputIds[i + 1] = vocabulary.getIndex(tokens.get(i));
        }
        inputIds[inputIds.length - 1] = vocabulary.getIndex(SEP_TOKEN);

        // No padding is added, so every position is a real token.
        long[] attentionMask = new long[inputIds.length];
        Arrays.fill(attentionMask, 1);

        NDManager manager = ctx.getNDManager();

        NDArray inputIdsArray = manager.create(inputIds);
        inputIdsArray.setName("input.input_ids");

        NDArray attentionMaskArray = manager.create(attentionMask);
        attentionMaskArray.setName("input.attention_mask");

        // NOTE(review): token_type_ids was previously allocated here but never added to the
        // returned list; the dead allocation has been removed. Only two inputs are passed to
        // the model - confirm the exported model does not expect token_type_ids.
        return new NDList(inputIdsArray, attentionMaskArray);
    }

}
